{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 其他 C++ 打包函数的例子"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加法偏函数\n",
    "\n",
    "```{literalinclude} ../cpp/bind_add/src/tvm_ext.cc\n",
    ":language: c++\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm -rf outputs/*\n",
      "g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\\<tvm/runtime/logging.h\\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the bind_add example from a clean state. The recorded output shows the old\n",
    "# contents of outputs/ being removed and src/tvm_ext.cc being compiled into\n",
    "# outputs/libs/libtvm_ext.so.\n",
    "!cd ../cpp/bind_add && make clean && make"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm_book.tvm_ext.libinfo import load_lib\n",
    "\n",
    "# Load the shared library that the build cell wrote to ../cpp/bind_add/outputs/libs.\n",
    "_LIB_EXT, _LIB_EXT_NAME = load_lib(\n",
    "    name=\"libtvm_ext.so\",\n",
    "    search_path=[\"../cpp/bind_add/outputs/libs\"],\n",
    ")\n",
    "\n",
    "# Initialise the \"tvm_ext\" API for this module; the following cells use `bind_add`\n",
    "# as a bare name, which is presumably made available by this call.\n",
    "tvm._ffi._init_api(\"tvm_ext\", __name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tvm.runtime.packed_func.PackedFunc at 0x7f7e8240d8d0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `bind_add` is not defined by any Python statement in this notebook; presumably it was\n",
    "# added to the namespace by the `_init_api` call in the previous cell. Displaying it\n",
    "# shows what kind of object it is.\n",
    "bind_add"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add(a, b):\n",
    "    \"\"\"Return the sum of the two operands; used as the callback handed to `bind_add`.\"\"\"\n",
    "    total = a + b\n",
    "    return total\n",
    "\n",
    "\n",
    "# Bind the value 7 to `add`; the result is called with the single remaining operand.\n",
    "f = bind_add(add, 7)\n",
    "bound_result = f(2)\n",
    "assert bound_result == 9"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## C++ 外部设备的例子"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{literalinclude} ../cpp/device_api/src/tvm_ext.cc\n",
    ":language: c++\n",
    "```\n",
    "编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm -rf outputs/*\n",
      "g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\\<tvm/runtime/logging.h\\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the device_api example from a clean state. The recorded output shows the old\n",
    "# contents of outputs/ being removed and src/tvm_ext.cc being compiled into\n",
    "# outputs/libs/libtvm_ext.so.\n",
    "!cd ../cpp/device_api && make clean && make"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm_book.tvm_ext.libinfo import load_lib\n",
    "\n",
    "# Load the shared library that the build cell wrote to ../cpp/device_api/outputs/libs.\n",
    "_LIB_EXT, _LIB_EXT_NAME = load_lib(\n",
    "    name=\"libtvm_ext.so\",\n",
    "    search_path=[\"../cpp/device_api/outputs/libs\"],\n",
    ")\n",
    "\n",
    "# Initialise the \"tvm_ext\" API for this module now that the library is loaded.\n",
    "tvm._ffi._init_api(\"tvm_ext\", __name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tvm import te\n",
    "\n",
    "n = 10\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.compute((n,), lambda *i: A(*i) + 1.0, name=\"B\")\n",
    "s = te.create_schedule(B.op)\n",
    "\n",
    "\n",
    "def check_llvm():\n",
    "    \"\"\"Build ``B = A + 1`` for ``Target(\"ext_dev\", \"llvm\")`` and check it on ``tvm.ext_dev(0)``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the values read back from the device differ from the\n",
    "            NumPy reference ``a + 1``.\n",
    "    \"\"\"\n",
    "    f = tvm.build(s, [A, B], tvm.target.Target(\"ext_dev\", \"llvm\"))\n",
    "    dev = tvm.ext_dev(0)\n",
    "    # Fixed seed so that Restart & Run All feeds the kernel the same input every time.\n",
    "    rng = np.random.default_rng(0)\n",
    "    a = tvm.nd.array(rng.uniform(size=n).astype(A.dtype), dev)\n",
    "    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "    # Launch the kernel; `b` starts as zeros and is compared against `a + 1` afterwards.\n",
    "    f(a, b)\n",
    "    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)\n",
    "\n",
    "\n",
    "check_llvm()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 回调 C++ 端外部函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{literalinclude} ../cpp/extern_func/src/tvm_ext.cc\n",
    ":language: c++\n",
    "```\n",
    "编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm -rf outputs/*\n",
      "g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\\<tvm/runtime/logging.h\\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the extern_func example from a clean state. The recorded output shows the old\n",
    "# contents of outputs/ being removed and src/tvm_ext.cc being compiled into\n",
    "# outputs/libs/libtvm_ext.so.\n",
    "!cd ../cpp/extern_func && make clean && make"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm_book.tvm_ext.libinfo import load_lib\n",
    "\n",
    "# Load the shared library that the build cell wrote to ../cpp/extern_func/outputs/libs.\n",
    "_LIB_EXT, _LIB_EXT_NAME = load_lib(\n",
    "    name=\"libtvm_ext.so\",\n",
    "    search_path=[\"../cpp/extern_func/outputs/libs\"],\n",
    ")\n",
    "\n",
    "# Initialise the \"tvm_ext\" API for this module now that the library is loaded.\n",
    "tvm._ffi._init_api(\"tvm_ext\", __name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tvm import te\n",
    "\n",
    "n = 10\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "# Each element of B is produced by calling the extern symbol \"TVMTestAddOne\" on the\n",
    "# matching element of A.\n",
    "B = te.compute(\n",
    "    (n,), lambda *i: tvm.tir.call_extern(\"float32\", \"TVMTestAddOne\", A(*i)), name=\"B\"\n",
    ")\n",
    "s = te.create_schedule(B.op)\n",
    "\n",
    "\n",
    "def check_llvm():\n",
    "    \"\"\"Build the ``call_extern`` schedule for ``\"llvm\"`` and check it on ``tvm.cpu(0)``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the kernel output differs from the NumPy reference ``a + 1``.\n",
    "    \"\"\"\n",
    "    f = tvm.build(s, [A, B], \"llvm\")\n",
    "    dev = tvm.cpu(0)\n",
    "    # Fixed seed so that Restart & Run All feeds the kernel the same input every time.\n",
    "    rng = np.random.default_rng(0)\n",
    "    a = tvm.nd.array(rng.uniform(size=n).astype(A.dtype), dev)\n",
    "    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "    # Launch the kernel; `b` starts as zeros and is compared against `a + 1` afterwards.\n",
    "    f(a, b)\n",
    "    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)\n",
    "\n",
    "\n",
    "check_llvm()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 提取外部 C++ 函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{literalinclude} ../cpp/mini_runtime/src/tvm_ext.cc\n",
    ":language: c++\n",
    "```\n",
    "编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm -rf outputs/*\n",
      "g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\\<tvm/runtime/logging.h\\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the mini_runtime example from a clean state. The recorded output shows the old\n",
    "# contents of outputs/ being removed and src/tvm_ext.cc being compiled into\n",
    "# outputs/libs/libtvm_ext.so.\n",
    "!cd ../cpp/mini_runtime && make clean && make"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm_book.tvm_ext.libinfo import load_lib\n",
    "\n",
    "# Load the shared library that the build cell wrote to ../cpp/mini_runtime/outputs/libs.\n",
    "# The handle is kept in `_LIB` because the next cell reads `_LIB.TVMExtDeclare`.\n",
    "_LIB, _LIB_NAME = load_lib(\n",
    "    name=\"libtvm_ext.so\",\n",
    "    search_path=[\"../cpp/mini_runtime/outputs/libs\"],\n",
    ")\n",
    "\n",
    "# Initialise the \"tvm_ext\" API for this module now that the library is loaded.\n",
    "tvm._ffi._init_api(\"tvm_ext\", __name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the library's `TVMExtDeclare` symbol to `extract_ext_funcs`; the result is\n",
    "# indexed by function name below.\n",
    "declare_entry = _LIB.TVMExtDeclare\n",
    "fdict = tvm._ffi.registry.extract_ext_funcs(declare_entry)\n",
    "\n",
    "# The extracted \"mul\" entry is expected to multiply its two arguments.\n",
    "mul = fdict[\"mul\"]\n",
    "assert mul(3, 4) == 12"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
